Machine learning

This example is available as a Jupyter notebook here.

It can also be run on Google Colab here.

Set up the environment if this is executed on Google Colab.

Make sure to change the runtime type to GPU: go to Runtime -> Change runtime type -> GPU. Otherwise, rendering won't work in Google Colab. You can verify this with the quick device check after the setup cell below.

import os

try:
    import google.colab
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

if IN_COLAB:
    os.system("pip install --quiet 'x_xy[all_muj] @ git+https://github.com/SimiPixel/x_xy_v2'")
    os.system("pip install --quiet mediapy")

import x_xy
# automatically detects whether this is executed on Colab
x_xy.utils.setup_colab_env()
from x_xy.subpkgs import ml, exp, benchmark, sys_composer, sim2real, omc
import mediapy
import jax.numpy as jnp
import tree_utils
import jax


def load_systems():
    sys = exp.load_sys("S_04", morph_yaml_key="seg2", delete_after_morph=["seg5", "imu3"])
    sys_noimu, _ = sys_composer.make_sys_noimu(sys)

    def _geoms_replace_color(sys: x_xy.System, color):
        # drop the geoms attached to the root link (index 0) and recolor the remaining geoms
        link_idx_to_root = 0
        geoms = [g.replace(color=color) for g in sys.geoms if g.link_idx != link_idx_to_root]
        return sys.replace(geoms=geoms)

    # change the geom colors used for rendering the predicted motion
    prediction_color = (78 / 255, 163 / 255, 243 / 255, 1.0)
    sys_newcolor = _geoms_replace_color(sys_noimu, prediction_color)
    sys_render = sys_composer.inject_system(sys, sys_newcolor.add_prefix_suffix("hat_"))

    return sys, sys_noimu, sys_render
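
Before wiring everything together below, it can help to briefly inspect the morphed systems. This small sketch only uses attributes that the notebook itself relies on later (link_names, link_parents):

sys, sys_noimu, sys_render = load_systems()
# parent index per link; -1 denotes a connection to the worldbody
print(sys_noimu.link_names)
print(sys_noimu.link_parents)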


def load_data_and_prediction(motion, sys, sys_noimu, params):
    exp_data = exp.load_data("S_04", motion)
    xml_str = exp.load_xml_str("S_04")
    xs = sim2real.xs_from_raw(sys, exp.link_name_pos_rot_data(exp_data, xml_str), qinv=True)

    # slightly shift `transform1.pos.x`: move segments by -3 cm and IMUs by +3 cm along x
    translations, rotations = sim2real.unzip_xs(sys, xs)
    seg_mask = jnp.array([sys.name_to_idx(name) for name in sys.link_names[1:] if name[:3] != "imu"])
    imu_mask = jnp.array([sys.name_to_idx(name) for name in sys.link_names[1:] if name[:3] == "imu"])
    translations = translations.replace(pos=translations.pos.at[:, seg_mask, 0].set(translations.pos[:, seg_mask, 0] - 0.03))
    translations = translations.replace(pos=translations.pos.at[:, imu_mask, 0].set(translations.pos[:, imu_mask, 0] + 0.03))

    if sys.link_parents[sys.name_to_idx("seg2")] != -1:
        # a little extra for seg2
        seg_mask = jnp.array([sys.name_to_idx("seg2")])
        translations = translations.replace(pos=translations.pos.at[:, seg_mask, 0].set(translations.pos[:, seg_mask, 0] - 0.02))

    xs_translated = sim2real.zip_xs(sys, translations, rotations)

    X = {seg: {} for seg in ["seg2", "seg3", "seg4"]}
    for seg in X:
        imu_data = exp_data[seg]["imu_rigid"]
        # the magnetometer measurements are not used as network input
        imu_data.pop("mag")
        if seg == "seg3":
            # seg3's IMU input is set to zero (the middle segment is treated as having no IMU)
            imu_data = tree_utils.tree_zeros_like(imu_data)
        X[seg].update(imu_data)
    y = x_xy.rel_pose(sys_noimu, xs, sys)

    rnno = ml.RNNOFilter(params=params)
    rnno.init(sys_noimu, tree_utils.tree_slice(X, 0))
    yhat = tree_utils.tree_slice(rnno.predict(tree_utils.add_batch_dim(X)), 0)
    return xs_translated, X, y, yhat


def render(sys, sys_noimu, sys_render, xs, yhat):

    xs_noimu = sim2real.match_xs(sys_noimu, xs, sys)

    # `yhat` contains child-to-parent rotations, but we need parent-to-child;
    # this dictionary now contains all links that do not connect to the worldbody
    transform2hat_rot = jax.tree_map(lambda quat: x_xy.maths.quat_inv(quat), yhat)

    transform1, transform2 = sim2real.unzip_xs(sys_noimu, xs_noimu)

    # add the missing links to transform2hat, i.e. the links that connect to the worldbody
    transform2hat = []
    for i, name in enumerate(sys_noimu.link_names):
        if name in transform2hat_rot:
            transform2_name = x_xy.Transform.create(rot=transform2hat_rot[name])
        else:
            transform2_name = transform2.take(i, axis=1)
        transform2hat.append(transform2_name)

    # after transpose shape is (n_timesteps, n_links, ...)
    transform2hat = transform2hat[0].batch(*transform2hat[1:]).transpose((1, 0, 2))

    xshat = sim2real.zip_xs(sys_noimu, transform1, transform2hat)

    # swap time axis and link axis
    xs, xshat = xs.transpose((1, 0, 2)), xshat.transpose((1, 0, 2))
    # create mapping from `name` -> Transform
    xs_dict = dict(
        zip(
            ["hat_" + name for name in sys_noimu.link_names],
            [xshat[i] for i in range(sys_noimu.num_links())],
        )
    )
    xs_dict.update(
        dict(
            zip(
                sys.link_names,
                [xs[i] for i in range(sys.num_links())],
            )
        )
    )

    xs_render = []
    for name in sys_render.link_names:
        xs_render.append(xs_dict[name])
    xs_render = xs_render[0].batch(*xs_render[1:])
    xs_render = xs_render.transpose((1, 0, 2))
    # render only every 4th timestep
    N = xs_render.shape()
    xs_render = [xs_render[t] for t in range(0, N, 4)]

    frames = x_xy.render(sys_render, xs_render, width=640, height=480, camera="c", 
                         add_cameras={-1: '<camera name="c" mode="targetbody" target="3" pos=".5 -.5 1.25"/>',})

    return frames


params = ml.load(pretrained="rr_rr_unknown")
motion = "thomas_fast"

sys, sys_noimu, sys_render = load_systems()
xs, X, y, yhat = load_data_and_prediction(motion, sys, sys_noimu, params)
frames = render(sys, sys_noimu, sys_render, xs, yhat)
Rendering frames..: 100%|██████████| 1150/1150 [00:06<00:00, 164.48it/s]
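
At this point X holds the IMU measurements per segment, and y / yhat hold the measured and predicted relative orientations. If you want to see the exact structure, printing the leaf shapes is a quick, optional way to do so; it only assumes that all leaves are arrays:

print(jax.tree_map(lambda arr: arr.shape, X))
print(jax.tree_map(lambda arr: arr.shape, yhat))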

mediapy.show_video(frames, fps=25.0)
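
Besides the qualitative comparison in the video, a simple quantitative check is the angle between measured and predicted relative orientation per joint. The sketch below assumes that y and yhat are dictionaries mapping joint names to quaternion time series (which is how yhat is used inside render), that x_xy.maths.quat_mul composes quaternions (quat_inv is already used above), and that the real part of the quaternion is stored first:

import numpy as np

def quat_angle_deg(q):
    # rotation angle of a unit quaternion with the real part at index 0
    return np.rad2deg(2.0 * np.arccos(np.clip(np.abs(np.asarray(q[..., 0])), 0.0, 1.0)))

# mean angular error per joint in degrees
errors = {
    name: float(quat_angle_deg(x_xy.maths.quat_mul(y[name], x_xy.maths.quat_inv(yhat[name]))).mean())
    for name in yhat if name in y
}
print(errors)

If you want to keep the rendered clip, mediapy.write_video("prediction.mp4", frames, fps=25.0) writes it to disk.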